import numpy
import matplotlib.pyplot as pyplot
import scipy.optimize as optimise
import os
from jqc import jqc_plot
import scipy.constants as constants
from matplotlib import gridspec as gridspec
import scipy.integrate
from matplotlib.ticker import ScalarFormatter,LogFormatterSciNotation,\
                                LogFormatter,LogFormatterExponent
###############################################################################
# Script constants and figure set-up
###############################################################################

# Shorthand aliases for scipy's physical constants (SI units).
c, h, hbar, u = constants.c, constants.h, constants.hbar, constants.u
kB = constants.Boltzmann
pi = numpy.pi

# Masses of 87Rb and 133Cs in atomic mass units.
mRb, mCs = 87., 133.

# Total molecular mass (kg).
m = (mRb + mCs)*u

# Temperature (K) used below when converting numbers to peak densities.
T = 1.5e-6

# Reduced mass (kg); not referenced elsewhere in this script.
mu = (mRb*mCs)/(mRb + mCs)*u

# Conversion factors.
conv_waveno_joule = h*c*100   # energy in J of 1 cm^-1 (E = h*c*wavenumber)
waveno_to_GHz = 29.9792458    # speed of light in cm/ns, i.e. GHz per cm^-1

# Colour palette from the group's plotting package.
JQC = jqc_plot.colours

# Directory containing this script; data/output paths are built from it.
cwd = os.path.dirname(os.path.abspath(__file__))

# Apply the plotting style, then create the figure.
jqc_plot.plot_style('normal')

# 2x2 layout: the narrower left column holds the lifetime panel (spanning
# both rows); the right column holds the density (top) and intensity
# (bottom) panels.
grid = gridspec.GridSpec(2, 2, width_ratios=[0.8, 1], height_ratios=[1, 1])
fig = pyplot.figure("Fig2")

##############################################################################
#Function Definitions
##############################################################################
def GetAveragedData(Data):
    """Average repeated measurements that share the same time stamp.

    Parameters
    ----------
    Data : array_like, shape (n_rows, n_cols >= 2)
        Column 0 holds the time of each shot and column 1 the measured
        value. Any further columns are ignored. A single row (1-D input) is
        accepted.

    Returns
    -------
    numpy.ndarray, shape (n_times, 3)
        One row per distinct time, in ascending order:
        ``[time, mean value, standard error of the mean]``. The standard
        error uses the population standard deviation (``numpy.std`` default,
        ``ddof=0``), so a time with a single shot gets an error of zero.
        An input with no rows gives an empty (0, 3) array.
    """
    Data = numpy.atleast_2d(numpy.asarray(Data, dtype=float))
    # Sort by the time (first) column so equal times form contiguous runs.
    Data = Data[numpy.argsort(Data[:, 0])]
    # Sorted distinct times and the index at which each run starts; splitting
    # at those indices gives one group per time, including single-shot
    # groups (which the previous row-stacking loop crashed on).
    Times, Starts = numpy.unique(Data[:, 0], return_index=True)
    Groups = numpy.split(Data[:, 1], Starts[1:])
    AveragedData = [[Time, numpy.average(Group),
                     numpy.std(Group)/numpy.sqrt(len(Group))]
                    for Time, Group in zip(Times, Groups)]
    # reshape keeps a (0, 3) shape when there are no rows at all.
    return numpy.array(AveragedData, dtype=float).reshape(-1, 3)

def TwoBodyBkgd(Input, t, a, b):
    """Right-hand side of the coupled number/temperature rate equations.

    Written for ``scipy.integrate.odeint``'s calling convention
    ``f(y, t, *args)``.

    Parameters
    ----------
    Input : sequence of two floats
        Current state ``(N, T)``.
    t : float
        Time; not used (the equations have no explicit time dependence).
    a, b : float
        Rate coefficients; only their absolute values are used.

    Returns
    -------
    tuple of float
        ``(dN/dt, dT/dt)`` where
        dN/dt = -|a|*N - |b|*N**2/T**1.5 and dT/dt = (|b|/4)*N/sqrt(T).
    """
    N, Temp = Input
    rate_a, rate_b = abs(a), abs(b)
    dN_dt = -rate_a*N - rate_b*(N**2)/(Temp**(3/2))
    dT_dt = (rate_b/4)*(N/numpy.sqrt(Temp))
    return dN_dt, dT_dt

def Calcb(ValueRaw):
    """Rescale a fitted two-body coefficient for use as ``b`` in TwoBodyBkgd.

    Returns ``ValueRaw / ((4*pi*kB)/(m*w**2))**1.5 * 1e-6``. ``kB``, ``m``
    and the trap angular frequency ``w`` are module-level globals looked up
    at call time, so ``w`` must already be bound (to the right trap) when
    this is called.

    NOTE(review): dimensional analysis against TwoBodyBkgd (which uses
    b*N**2/T**1.5) suggests ValueRaw is a coefficient in cm^3 s^-1 and the
    1e-6 converts cm^3 to m^3, giving b in K^1.5 s^-1 — confirm the units of
    the stored parameters.
    """
    scale = ((4*numpy.pi*kB)/(m*(w**2)))**1.5
    return (ValueRaw/scale)*1e-6

def TwoBodyBkgdFitFunction(xData, N0, a, b, *, T0=1.5e-6, Ndat=1000):
    """Model molecule number at the times ``xData`` (for use with curve_fit).

    Integrates ``TwoBodyBkgd`` from t = 0, where the state is ``(N0, T0)``,
    and returns N at every entry of ``xData``. Repeated times give repeated
    values, so the output lines up element-by-element with raw data.

    Parameters
    ----------
    xData : array_like
        Sample times. Assumed sorted ascending and >= 0, because odeint
        needs a monotonic time grid.
    N0, a, b : float
        Initial number and the two rate coefficients passed to TwoBodyBkgd.
    T0 : float, keyword-only
        Initial temperature; the default equals the module-level ``T`` and
        the initial temperature used for the lifetime curves. Keyword-only
        so ``curve_fit`` still infers exactly three fit parameters.
    Ndat : int, keyword-only
        Number of output grid points odeint reports between consecutive
        distinct times (the file previously read an undefined global
        ``Ndat``; a ``1e3`` value appears in a commented-out line).

    Returns
    -------
    numpy.ndarray
        N(t) at each element of ``xData``.
    """
    xData = numpy.atleast_1d(numpy.asarray(xData, dtype=float))
    Ndat = int(Ndat)  # linspace needs an integer point count

    # Monotonic output grid starting at t = 0 (where the initial condition
    # applies), with Ndat points up to each new time; SampleIndex records the
    # grid index that corresponds to each entry of xData. Starting the
    # comparison from 0 makes a non-zero first time get its own segment
    # instead of being reported as N(0).
    Segments = [numpy.zeros(1)]
    SampleIndex = numpy.empty(len(xData), dtype=int)
    Last, GridEnd = 0.0, 0
    for i, x in enumerate(xData):
        if x != Last:
            Step = abs(x - Last)/Ndat
            Segments.append(numpy.linspace(Last + Step, x, Ndat))
            GridEnd += Ndat
            Last = x
        SampleIndex[i] = GridEnd
    Times = numpy.concatenate(Segments)

    # Swap TwoBodyBkgd here to fit a different rate model.
    Init = N0, T0
    FitLine = scipy.integrate.odeint(TwoBodyBkgd, Init, t=Times, args=(a, b))
    return FitLine[SampleIndex, 0]

###############################################################################
# Plotting the intensity dependence
###############################################################################

ax1 = fig.add_subplot(grid[1,1])

# Columns (per the axis labels below): intensity in kW cm^-2, loss
# coefficient k_l in s^-1, and the uncertainty of k_l.
# os.path.join keeps the path valid on every OS, not only Windows.
Data = numpy.genfromtxt(os.path.join(cwd, "data", "data_intensity_fixedK2.csv"),
                        delimiter=',')


def _reduced_chi2(data, fn, params, n_params):
    """Return chi-squared per degree of freedom of ``fn(x, *params)``.

    ``data`` columns are x, y and the standard deviation of y. Every row is
    used, and the degrees of freedom are ``len(data) - n_params``.
    """
    residual_sq = (data[:,1]-fn(data[:,0],*params))**2/(data[:,2])**2
    return numpy.sum(residual_sq)/(len(data[:,0])-float(n_params))


# Power laws k_l = B*I**kappa with kappa fixed at 1, 2, 3 and with kappa
# free. Each model takes only the parameters it uses, so curve_fit can
# estimate a finite covariance for all of them (the linear model previously
# carried an unused parameter, which made its covariance degenerate).

# kappa = 1
fitfn1 = lambda x,B: B*x**1
curve1,err1 = optimise.curve_fit(fitfn1,Data[:,0],Data[:,1],sigma=Data[:,2],
                                 absolute_sigma=True)
err1 = numpy.sqrt(numpy.diag(err1))
chi2 = _reduced_chi2(Data,fitfn1,curve1,1)
print("linear:",chi2)

# kappa = 2, fitted without the last row.
# NOTE(review): this chi-squared is evaluated over all rows, including the
# one excluded from the fit; the linear, cubic and free-exponent fits use
# every row, while the data plot omits the last row — confirm which rows
# each fit is meant to use.
fitfn2 = lambda x,B: B*x**2
curve2,err2 = optimise.curve_fit(fitfn2,Data[:-1,0],Data[:-1,1],
                                 sigma=Data[:-1,2],absolute_sigma=True)
err2 = numpy.sqrt(numpy.diag(err2))
chi2 = _reduced_chi2(Data,fitfn2,curve2,1)
print("quad:",chi2)

# kappa = 3
fitfn3 = lambda x,B: B*x**3
curve3,err3 = optimise.curve_fit(fitfn3,Data[:,0],Data[:,1],sigma=Data[:,2],
                                 absolute_sigma=True)
err3 = numpy.sqrt(numpy.diag(err3))
chi2 = _reduced_chi2(Data,fitfn3,curve3,1)
print("cube:",chi2)

# kappa free: k_l = B*I**C
fitfna = lambda x,B,C: B*x**C
curvea,erra = optimise.curve_fit(fitfna,Data[:,0],Data[:,1],sigma=Data[:,2],
                                 absolute_sigma=True)
erra = numpy.sqrt(numpy.diag(erra))
chi2 = _reduced_chi2(Data,fitfna,curvea,2)
print("arb:",chi2)

# Log-log axes with non-positive values clipped. Newer Matplotlib only
# accepts `nonpositive`; older releases only accept `nonposx`/`nonposy`.
try:
    ax1.set_xscale("log", nonpositive='clip')
    ax1.set_yscale("log", nonpositive='clip')
except (TypeError, ValueError):
    ax1.set_xscale("log", nonposx='clip')
    ax1.set_yscale("log", nonposy='clip')

DataRange = numpy.linspace(0,2,150)

# Last row omitted from the plot; x error bars are 0.84 % of the intensity.
ax1.errorbar(Data[:-1,0],Data[:-1,1],yerr=Data[:-1,2],xerr=Data[:-1,0]*0.0084,
             fmt='o',color='k',zorder=1.8,capsize=3)

ax1.text(0.15,6,"$\\kappa$=1",color=JQC['blue'])
ax1.plot(DataRange,fitfn1(DataRange,*curve1),color=JQC['blue'],label='$k$=1',
         zorder=1.1,ls='--')

# Free-exponent fit with an error band: the changes from shifting B and C
# individually by their standard errors are added in quadrature (their
# covariance is ignored).
fitted = fitfna(DataRange,*curvea)
fitted_err1 = ((curvea[0]+erra[0])*DataRange**curvea[1])-fitted
fitted_err2 = (curvea[0]*DataRange**(curvea[1]+erra[1]))-fitted
print(curvea,erra)
fitted_err = numpy.sqrt(fitted_err1**2+fitted_err2**2)

ax1.plot(DataRange,fitted,color=jqc_plot.colours['red'],zorder=1.05)
# NOTE(review): both `color` and `facecolor` are passed; which one Matplotlib
# applies to the fill depends on its property-setting order — confirm the
# intended fill colour.
ax1.fill_between(DataRange,fitted-fitted_err,fitted+fitted_err,
                 color=jqc_plot.colours["reddish"],zorder=1.02,alpha=.5,facecolor='k')

ax1.set_ylabel("Loss Coefficient, \n $k_\\mathrm{l}$ (s$^{-1}$)")
ax1.set_xlabel("Intensity (kW$\\,$cm$^{-2}$)")

ax1.set_ylim(1.4,200)
ax1.set_xlim(0.1,2.0)

ax1.set_xticks([0.1,1])
ax1.set_xticks([0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,2],minor =True)
ax1.set_xticklabels(["0.1","1"])
ax1.set_xticklabels(["","","","0.5","","","","","2"],minor=True)

print("exponent:",curvea[1],erra[1])
print("coefficient:",curvea[0],erra[0])

###############################################################################
####################### Plot the density dependence ###########################
###############################################################################
# Trap angular frequencies (rad s^-1) for these measurements
w_x = 2*pi*181
w_y = 2*pi*44
w_z = 2*pi*178

ax2 = fig.add_subplot(grid[0,1])

# Columns 0-1 are converted to density below (presumably number and its
# uncertainty — confirm); columns 2-3 are the fractional loss rate and its
# uncertainty, per the y-axis label.
Data = numpy.genfromtxt(os.path.join(cwd, "data", "data_density[linear].csv"),
                        delimiter=',')

# Peak density of a thermal cloud in a harmonic trap,
# n0 = N*w_x*w_y*w_z*(m/(2*pi*kB*T))**1.5, converted from m^-3 to cm^-3.
Data[:,0:2] = 1e-6*Data[:,0:2]*w_x*w_y*w_z*(m/(2*pi*kB*T))**1.5

# Plot densities in units of 1e11 cm^-3.
ax2.errorbar(Data[:,0]*1e-11,Data[:,2],yerr=Data[:,3],xerr=Data[:,1]*1e-11,
             fmt='o',color='k',zorder=1.8,capsize=3)

# Inverse-variance weighted mean loss rate; `returned=True` gives the sum of
# weights, whose inverse square root is the uncertainty of the mean.
avg,err = numpy.average(Data[:,2],weights=1/(Data[:,3])**2,returned=True)
err = err**(-0.5)
print(avg,err)

ax2.axhspan(avg-err,avg+err,color=JQC['grayblue'],zorder=1.2)

ax2.set_ylabel("Fractional \n Loss Rate (ms$^{-1}$)")
ax2.set_xlabel("Initial Density,\n $n_0$ (10$^{11}\\,$cm$^{-3}$)")

ax2.set_xlim(0,2)
ax2.set_ylim(0,0.06)

xmin,xmax = ax2.get_xlim()
ax2.yaxis.set_major_formatter(ScalarFormatter())
ax2.plot([xmin,xmax],[avg,avg],color=JQC['blue'],zorder=1.3,ls='--')
plot_density = numpy.linspace(xmin,xmax,500)

# Unweighted fit of a rate proportional to density (the mean above is
# weighted; this fit is not).
curve,cov = optimise.curve_fit(lambda x,A: x*A,Data[:,0]*1e-11,Data[:,2],p0=[9])
print(curve,numpy.sqrt(cov))
ax2.plot(plot_density,curve[0]*plot_density,ls='--',color=jqc_plot.colours['red'],
         zorder=1.25)
ax2.text(1.5,0.05,"$\\propto n_0$",color=jqc_plot.colours['red'])


##############################################################################
#third subplot, lifetimes
##############################################################################

# Geometric-mean trap angular frequency (rad s^-1) for these measurements.
# Calcb reads this module-level `w`, so it must be set before Calcb is called.
w = 2*numpy.pi*(174.8*174.8*42.8)**(1/3)

colors=[JQC['red'],JQC['blue'],JQC['green']]
colors2 = [JQC['reddish'],JQC['grayblue'],JQC['greenish']]

ax3 = fig.add_subplot(grid[:,0])

# os.path.join keeps the paths valid on every OS, not only Windows.
lifetime_dir = os.path.join(cwd, "data", "Lifetimes")

# Fitted parameters, one row per data set, in the order unpacked below:
# N0, its error, a, its error, b, its error.
pars = numpy.genfromtxt(os.path.join(lifetime_dir, "Lifetimes_pars.csv"),
                        delimiter=',')

# Row of `pars` belonging to each intensity plotted in the loop.
files =[0,1,4]

for i,intensity in enumerate(["0.0","0.3","1.2"]):
    data = numpy.genfromtxt(os.path.join(lifetime_dir, "I="+intensity+".csv"),
                            delimiter=',')
    AveragedData = GetAveragedData(data)

    p = pars[files[i],:]
    N0,Nerr,a,aerr,b,berr = p
    b = Calcb(b)

    # Integrate the model from (N0, T = 1.5e-6) over 0-1.8 (plotted x1e3 in ms).
    Init = N0,1.5e-6
    FitTimes = numpy.linspace(0, 1.8, int(1e6))
    Fit2BodyLine = scipy.integrate.odeint(TwoBodyBkgd, Init, t = FitTimes,
                                          args=(a,b))

    # Data (and their errors) are normalised by 0.83*N_model(0); the model
    # curve by N_model(0).
    data_norm = Fit2BodyLine[0,0]*0.83
    ax3.errorbar(AveragedData[:,0]*1e3,
                 AveragedData[:,1]/data_norm,
                 yerr=AveragedData[:,2]/data_norm,
                 color=colors[i],fmt='o',
                 zorder=1.2,label=intensity+" kW$\\,$cm$^{-2}$",
                 capsize=3.5,ms=5)

    ax3.plot(FitTimes*1e3, (Fit2BodyLine[:,0]/Fit2BodyLine[0,0]),
             color=colors2[i],
             zorder=1.1)

# Off-resonant measurement: last row of `pars`, taken with a different trap
# frequency, so `w` is rebound before Calcb reads it.
data = numpy.genfromtxt(os.path.join(lifetime_dir, "Offresonant.csv"),
                        delimiter=',')

AveragedData = GetAveragedData(data)

p=pars[-1,:]
print(p)
N0,Nerr,a,aerr,b,berr = p
w = 2*numpy.pi*(170*170*170)**(1/3)
b = Calcb(b)

Init=N0,2.6e-6

FitTimes = numpy.linspace(0, 1.8, int(1e6))

Fit2BodyLine = scipy.integrate.odeint(TwoBodyBkgd, Init, t = FitTimes,
                                      args=(a,b))

# Rescale the mean AND its standard error by the same factor so the error
# bars stay consistent with the rescaled points (previously only the mean
# was rescaled; the loop above scales both).
AveragedData[:,1:3] = AveragedData[:,1:3]/0.89

# NOTE(review): unlike the loop above, these times are not multiplied by 1e3,
# although the model curve is drawn against FitTimes*1e3 — confirm that
# Offresonant.csv stores its times in ms.
ax3.errorbar(AveragedData[:,0],AveragedData[:,1]/(Fit2BodyLine[0,0]),
             yerr=AveragedData[:,2]/(Fit2BodyLine[0,0]),ecolor='k',fmt='o',
             zorder=1.2,capsize=3.5,ms=5,markerfacecolor = 'w',
             markeredgecolor='k')

ax3.plot(FitTimes*1e3, (Fit2BodyLine[:,0]/Fit2BodyLine[0,0]),color='k',
         zorder=1.1,ls='--')

ax3.legend(fontsize=14,handletextpad=0.0,frameon=False,
           loc=(0.1,0.65),markerscale = 0.75)
ax3.set_xlim(-20,550)
ax3.set_ylim(-0.05,1.25)
ax3.set_xlabel("Time (ms)")
ax3.set_ylabel("Fraction of Molecules Remaining")

###############################################################################
# Finishing touches, labels/ text etc.
###############################################################################

# Panel labels, positioned in axes-fraction coordinates.
ax1.text(0.03,0.8,"(c)",transform=ax1.transAxes,fontsize=20)
ax2.text(0.03,0.8,"(b)",transform=ax2.transAxes,fontsize=20)
ax3.text(0.03,0.93,"(a)",transform=ax3.transAxes,fontsize=20)

fig.tight_layout()
# These explicit margins override the spacing chosen by tight_layout above.
fig.subplots_adjust(hspace=0.65,wspace=0.58,top=0.95,bottom=0.16,left=0.13,
                    right=0.97)

# Build output paths portably (the previous "\\OUTPUT\\..." strings only
# worked on Windows) and create the directory if it does not exist yet.
output_dir = os.path.join(cwd, "OUTPUT")
os.makedirs(output_dir, exist_ok=True)
fig.savefig(os.path.join(output_dir, "IntensityandDensityV2.pdf"))
fig.savefig(os.path.join(output_dir, "IntensityandDensityV2.png"))


pyplot.show()
